{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 初始化方法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch_geometric.nn.inits源码"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "def uniform(size, tensor):\n",
    "    if tensor is not None:\n",
    "        bound = 1.0 / math.sqrt(size)\n",
    "        tensor.data.uniform_(-bound, bound)\n",
    "\n",
    "\n",
    "def kaiming_uniform(tensor, fan, a):\n",
    "    if tensor is not None:\n",
    "        bound = math.sqrt(6 / ((1 + a**2) * fan))\n",
    "        tensor.data.uniform_(-bound, bound)\n",
    "\n",
    "\n",
    "def glorot(tensor):\n",
    "    if tensor is not None:\n",
    "        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))\n",
    "        tensor.data.uniform_(-stdv, stdv)\n",
    "\n",
    "\n",
    "def glorot_orthogonal(tensor, scale):\n",
    "    if tensor is not None:\n",
    "        torch.nn.init.orthogonal_(tensor.data)\n",
    "        scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var())\n",
    "        tensor.data *= scale.sqrt()\n",
    "\n",
    "\n",
    "def zeros(tensor):\n",
    "    if tensor is not None:\n",
    "        tensor.data.fill_(0)\n",
    "\n",
    "\n",
    "def ones(tensor):\n",
    "    if tensor is not None:\n",
    "        tensor.data.fill_(1)\n",
    "\n",
    "\n",
    "def normal(tensor, mean, std):\n",
    "    if tensor is not None:\n",
    "        tensor.data.normal_(mean, std)\n",
    "\n",
    "\n",
    "def reset(nn):\n",
    "    def _reset(item):\n",
    "        if hasattr(item, 'reset_parameters'):\n",
    "            item.reset_parameters()\n",
    "\n",
    "    if nn is not None:\n",
    "        if hasattr(nn, 'children') and len(list(nn.children())) > 0:\n",
    "            for item in nn.children():\n",
    "                _reset(item)\n",
    "        else:\n",
    "            _reset(nn)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 样例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-15T11:42:46.883963Z",
     "start_time": "2021-06-15T11:42:44.962577Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "# Import only the initializers this notebook calls. A wildcard import would\n",
    "# hide where ones / zeros / glorot come from.\n",
    "from torch_geometric.nn.inits import glorot, ones, zeros"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-11T07:57:30.898112Z",
     "start_time": "2021-06-11T07:57:30.892730Z"
    }
   },
   "outputs": [],
   "source": [
    "# 2x5 float32 tensor that the initializer calls in the next cells overwrite in place.\n",
    "# torch.tensor with an explicit dtype replaces the legacy torch.FloatTensor constructor.\n",
    "a = torch.tensor([[5, 4, 3, 2, 1], [1, 2, 3, 4, 5]], dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-11T07:57:31.165970Z",
     "start_time": "2021-06-11T07:57:31.158362Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 1., 1., 1.],\n",
       "        [1., 1., 1., 1., 1.]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ones(a)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-11T07:57:31.408628Z",
     "start_time": "2021-06-11T07:57:31.401271Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "zeros(a)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# glorot() reads tensor.size(-2) and tensor.size(-1), so the tensor must have\n",
    "# at least two dimensions. torch.ones((3)) is 1-D because (3) is the int 3,\n",
    "# not a tuple, and calling glorot() on it raises IndexError.\n",
    "a = torch.ones((3, 3))\n",
    "glorot(a)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "song",
   "language": "python",
   "name": "song"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
